{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorials on Entity Embedding\n",
    "In this short tutorial, we demonstrate how to contruct entity embedding using vincinity information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "EPS = 0.00001"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Note\n",
    "This is the final implementation, but not the process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EntityEmbeddingLayer(nn.Module):\n",
    "    def __init__(self, num_level, emdedding_dim, centroid):\n",
    "        super(EntityEmbeddingLayer, self).__init__()\n",
    "        self.embedding = nn.Embedding(num_level, embedding_dim)\n",
    "        self.centroid = torch.tensor(centroid).detach_().unsqueeze(1)\n",
    "    def forward(self, x): \n",
    "        \"\"\"\n",
    "        x must be batch_size times 1\n",
    "        \"\"\"\n",
    "        x = x.unsqueeze(1)\n",
    "        d = 1.0/((x-self.centroid).abs()+EPS)\n",
    "        w = F.softmax(d.squeeze(2), 1)\n",
    "        v = torch.mm(w, self.embedding.weight)\n",
    "        return v\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experimentation\n",
    "This is the part where you experiment with the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_level = 10 # When configure dimension, do not use the same dimension for testing\n",
    "embedding_dim = 5 \n",
    "embedding = nn.Embedding(num_level, embedding_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 4\n",
    "x = torch.randn(batch_size, 1)\n",
    "centroid = torch.randn(num_level, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([10, 5])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embedding.weight.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 1, 1])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 10, 1])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "centroid.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = x.unsqueeze(1)\n",
    "centroid = centroid.unsqueeze(0) # This is not needed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 10, 1])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(x- centroid).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = 1.0/((x-centroid).abs()+EPS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "w = F.softmax(d.squeeze(2), 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 10])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "w.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(4.)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "w.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "v = torch.mm(w, embedding.weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 5])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  \"\"\"\n"
     ]
    }
   ],
   "source": [
    "x = torch.randn(batch_size, 1)\n",
    "centroid = torch.randn(num_level) # This differs from the above definition\n",
    "entity_embedding = EntityEmbeddingLayer(num_level, embedding_dim, centroid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.2952, -0.1353, -0.3040,  0.1783,  0.0248],\n",
       "        [-0.6496, -2.2369, -1.2419,  1.3052, -1.2015],\n",
       "        [ 1.0792,  1.3434,  0.5521,  0.6170,  0.6364],\n",
       "        [ 0.2842, -0.1270, -0.2875,  0.1578,  0.0382]], grad_fn=<MmBackward>)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "entity_embedding(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
